import argparse
from time import time
import math
import random
import numpy as np

import os

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import src as models
from utils.losses import LabelSmoothingCrossEntropy

# Names selectable via --model: every public (no leading underscore),
# all-lowercase, callable attribute of the `src` package, sorted.
model_names = sorted(
    attr_name
    for attr_name, attr_value in models.__dict__.items()
    if attr_name.islower()
    and not attr_name.startswith("_")
    and callable(attr_value)
)

# Best validation top-1 accuracy seen so far; updated by main().
best_acc1 = 0

# Per-dataset constants: number of classes, image side length, and the
# per-channel mean/std fed to transforms.Normalize.
DATASETS = {
    'cifar10': dict(
        num_classes=10,
        img_size=32,
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616],
    ),
    'cifar100': dict(
        num_classes=100,
        img_size=32,
        mean=[0.5071, 0.4867, 0.4408],
        std=[0.2675, 0.2565, 0.2761],
    ),
}


def init_parser():
    """Build the command-line parser for the CIFAR training script.

    Returns:
        argparse.ArgumentParser: parser whose ``--model`` choices are the
        module-level ``model_names``.
    """
    parser = argparse.ArgumentParser(description='CIFAR quick training script')

    # Data args
    parser.add_argument('--dataset', type=str.lower, choices=['cifar10', 'cifar100'], default='cifar10')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
    parser.add_argument('--print-freq', default=10, type=int, metavar='N', help='log frequency (by iteration)')
    parser.add_argument('--checkpoint-path', type=str, default='checkpoint.pth', help='path to checkpoint (default: checkpoint.pth)')

    # Optimization hyperparams
    parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--warmup', default=5, type=int, metavar='N', help='number of warmup epochs')
    parser.add_argument('-b', '--batch-size', default=128, type=int, metavar='N', help='mini-batch size (default: 128)', dest='batch_size')
    parser.add_argument('--lr', default=0.0005, type=float, help='initial learning rate')
    # Help text previously advertised 1e-4, which did not match the real default.
    parser.add_argument('--weight-decay', default=3e-2, type=float, help='weight decay (default: 3e-2)')
    parser.add_argument('--clip-grad-norm', default=0., type=float, help='gradient norm clipping (default: 0 (disabled))')

    # Model
    parser.add_argument('-m', '--model', type=str.lower, choices=model_names, default='noisyvit_lite_7', dest='model')
    parser.add_argument('-p', '--positional-embedding', type=str.lower, choices=['learnable', 'sine', 'none'], default='learnable', dest='positional_embedding')
    parser.add_argument('--conv-layers', default=0, type=int, help='number of convolutional layers (cct only)')
    parser.add_argument('--conv-size', default=0, type=int, help='convolution kernel size (cct only)')
    parser.add_argument('--patch-size', default=4, type=int, help='image patch size (vit and cvt only)')
    parser.add_argument('--disable-cos', action='store_true', help='disable cosine lr schedule')
    parser.add_argument('--disable-aug', action='store_true', help='disable augmentation policies for training')

    # Regularisation.
    # NOTE(review): these are parsed but not currently forwarded to the model
    # constructor in main() — confirm that is intended.
    parser.add_argument('--dropout_rate', default=0.1, type=float)
    parser.add_argument('--attention_dropout', default=0.1, type=float)
    parser.add_argument('--stochastic_depth', default=0.1, type=float)

    # Noisy Feature Mixup: --alpha > 0 enables mixup (Beta(alpha, alpha) mixing
    # weight); --manifold_mixup 1 additionally mixes at hidden layers;
    # --add_noise / --mult_noise scale the injected additive / multiplicative noise.
    parser.add_argument('--alpha', default=0, type=float)
    parser.add_argument('--manifold_mixup', type=int, default=0, metavar='S', help='manifold mixup (default: 0)')
    parser.add_argument('--add_noise', default=0, type=float)
    parser.add_argument('--mult_noise', default=0, type=float)

    # Others
    parser.add_argument('--gpu-id', default=0, type=int)
    parser.add_argument('--no-cuda', action='store_true', help='disable cuda')
    parser.add_argument('--seed', default=1, type=int)

    return parser

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also requests deterministic cuDNN algorithms via
    ``torch.backends.cudnn.deterministic``.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.deterministic = True


def main():
    """Entry point: parse arguments, train on CIFAR-10/100, then save the model.

    Each epoch sets the learning rate, trains for one epoch, validates, and
    records the best top-1 accuracy in the module-level ``best_acc1``. After
    training it prints a summary. It then saves the whole model object with
    ``torch.save`` under ``<dataset>_models2/``.
    """
    global best_acc1

    args = init_parser().parse_args()
    seed_everything(args.seed)

    dataset_cfg = DATASETS[args.dataset]
    img_size = dataset_cfg['img_size']
    num_classes = dataset_cfg['num_classes']
    img_mean, img_std = dataset_cfg['mean'], dataset_cfg['std']

    # NOTE(review): --dropout_rate, --attention_dropout, --stochastic_depth,
    # --add_noise and --mult_noise are not passed to the constructor (they were
    # commented out here) — confirm this is intended for the selected model.
    model = models.__dict__[args.model](
        img_size=img_size,
        num_classes=num_classes,
        positional_embedding=args.positional_embedding,
        n_conv_layers=args.conv_layers,
        kernel_size=args.conv_size,
        patch_size=args.patch_size,
    )

    # utils.losses.LabelSmoothingCrossEntropy is imported as an alternative loss.
    criterion = torch.nn.CrossEntropyLoss()

    use_cuda = (not args.no_cuda) and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu_id)
        model.cuda(args.gpu_id)
        criterion = criterion.cuda(args.gpu_id)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    normalize = [transforms.Normalize(mean=img_mean, std=img_std)]

    train_steps = []
    if not args.disable_aug:
        from utils.autoaug import CIFAR10Policy
        train_steps.append(CIFAR10Policy())
    train_steps += [
        transforms.RandomCrop(img_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        *normalize,
    ]
    train_transform = transforms.Compose(train_steps)

    dataset_cls = datasets.__dict__[args.dataset.upper()]
    train_dataset = dataset_cls(
        root='./cifar/', train=True, download=True, transform=train_transform)
    val_dataset = dataset_cls(
        root='./cifar/', train=False, download=False, transform=transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            *normalize,
        ]))

    # Pinned host memory is what lets the non_blocking=True .cuda() copies made
    # per batch during training/validation actually run asynchronously.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=use_cuda)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=use_cuda)

    print("Beginning training")
    time_begin = time()
    # Defined up front so the summary below works even with --epochs 0.
    acc1 = 0.0
    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch, args)
        cls_train(train_loader, model, criterion, optimizer, epoch, args)
        acc1 = cls_validate(val_loader, model, criterion, args, epoch=epoch, time_begin=time_begin)
        best_acc1 = max(acc1, best_acc1)

    total_mins = (time() - time_begin) / 60
    print(f'Script finished in {total_mins:.2f} minutes, '
          f'best top-1: {best_acc1:.2f}, '
          f'final top-1: {acc1:.2f}')

    # NOTE(review): --checkpoint-path is parsed but unused; the model is saved
    # to the path built below instead.
    destination_path = args.dataset + '_models2/'
    out_dir = os.path.join(destination_path, f'arch_{args.model}_alpha_{args.alpha}_manimixup_{args.manifold_mixup}_addn_{args.add_noise}_multn_{args.mult_noise}_seed_{args.seed}')
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs(destination_path, exist_ok=True)
    torch.save(model, out_dir + '.pt')


def adjust_learning_rate(optimizer, epoch, args):
    """Write the learning rate for ``epoch`` into every optimizer param group.

    Schedule:
      * warmup epochs (``epoch < args.warmup``): ``args.lr / (args.warmup - epoch)``,
        i.e. lr/warmup, lr/(warmup-1), ..., lr;
      * afterwards, unless ``args.disable_cos``: half-cosine decay from
        ``args.lr`` over the remaining ``args.epochs - args.warmup`` epochs;
      * with ``args.disable_cos``: constant ``args.lr``.
    """
    new_lr = args.lr
    warming_up = hasattr(args, 'warmup') and epoch < args.warmup
    if warming_up:
        new_lr = new_lr / (args.warmup - epoch)
    elif not args.disable_cos:
        angle = math.pi * (epoch - args.warmup) / (args.epochs - args.warmup)
        new_lr = new_lr * (0.5 * (1. + math.cos(angle)))

    for group in optimizer.param_groups:
        group.update(lr=new_lr)


def accuracy(output, target):
    """Return top-1 accuracy of ``output`` against ``target`` as a percentage.

    ``output`` is ranked along dim 1 (one row of class scores per sample) and
    ``target`` holds one class index per row. The result is a one-element list
    containing a 1-element float tensor; callers read it as ``acc1[0]``.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        top1_idx = output.topk(1, dim=1, largest=True, sorted=True).indices.t()
        matches = top1_idx.eq(target.view(1, -1).expand_as(top1_idx))
        n_correct = matches[:1].flatten().float().sum(0, keepdim=True)
        return [n_correct.mul_(100.0 / n_samples)]




def _noise(x, add_noise_level=0.0, mult_noise_level=0.0, sparsity_level=0.0):
    add_noise = 0.0
    mult_noise = 1.0
    with torch.cuda.device(0):
        if add_noise_level > 0.0:
            add_noise = add_noise_level * np.random.beta(2, 5) * torch.cuda.FloatTensor(x.shape).normal_()
            #torch.clamp(add_noise, min=-(2*var), max=(2*var), out=add_noise) # clamp
        if mult_noise_level > 0.0:
            mult_noise = mult_noise_level * np.random.beta(2, 5) * (2*torch.cuda.FloatTensor(x.shape).uniform_()-1) + 1 
    return mult_noise * x + add_noise      

def do_noisy_mixup(x, y, alpha=0.0, add_noise_level=0.0, mult_noise_level=0.0):
    """Mix a batch with a random permutation of itself, then add noise.

    Args:
        x: input batch; mixing is along dim 0.
        y: targets aligned with ``x`` along dim 0.
        alpha: Beta(alpha, alpha) parameter for the mixing weight; when
            ``alpha <= 0`` the weight is 1.0, so no mixing occurs.
        add_noise_level, mult_noise_level: forwarded to ``_noise``.

    Returns:
        ``(noisy_mixed_x, y_a, y_b, lam)``. ``y_a`` is ``y`` unchanged, ``y_b``
        is ``y`` in the permuted order, and ``lam`` is the weight on the
        unpermuted batch.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0.0 else 1.0
    # Build the permutation on the batch's own device; the previous
    # unconditional .cuda() broke CPU batches.
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return _noise(mixed_x, add_noise_level=add_noise_level, mult_noise_level=mult_noise_level), y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)



def cls_train(train_loader, model, criterion, optimizer, epoch, args):
    """Run one training epoch, optionally with (manifold) Noisy Feature Mixup.

    The per-batch mode is chosen from ``args``:
      * ``args.alpha == 0``: plain supervised step on ``model(images)``.
      * ``args.alpha > 0``: a mixing site ``k`` is chosen. It is 0 by default,
        or drawn uniformly from 0..5 when ``args.manifold_mixup`` is nonzero.
        For k in {0, 5}, input-level noisy mixup is applied via
        ``do_noisy_mixup``. Otherwise the model is called as
        ``model(images, lam, k, args.add_noise, args.mult_noise)`` and mixes
        internally. Presumably it pairs each sample with the batch reversed
        along dim 0, since that is the target pairing used here; confirm
        against the model.

    Args:
        train_loader: iterable of ``(images, target)`` batches.
        model, criterion, optimizer: the usual training triple.
        epoch: unused; kept for the caller's signature.
        args: parsed CLI namespace (reads no_cuda, gpu_id, alpha,
            manifold_mixup, add_noise, mult_noise, clip_grad_norm).
    """
    model.train()
    use_cuda = (not args.no_cuda) and torch.cuda.is_available()
    for images, target in train_loader:
        if use_cuda:
            images = images.cuda(args.gpu_id, non_blocking=True)
            target = target.cuda(args.gpu_id, non_blocking=True)

        k = 0 if args.alpha > 0.0 else -1
        # Truthiness rather than `== True`, so any nonzero int enables it.
        if args.alpha > 0.0 and args.manifold_mixup:
            k = np.random.choice(range(6), 1)[0]

        if args.alpha == 0.0:
            output = model(images)
            loss = criterion(output, target)

        elif k == 0 or k == 5:
            inputs, targets_a, targets_b, lam = do_noisy_mixup(
                images, target, alpha=args.alpha,
                add_noise_level=args.add_noise, mult_noise_level=args.mult_noise)
            output = model(inputs)
            loss = mixup_criterion(criterion, output, targets_a, targets_b, lam)

        else:
            lam = np.random.beta(args.alpha, args.alpha)
            output = model(images, lam, k, args.add_noise, args.mult_noise)
            # Reverse along dim 0 on the targets' own device (was an
            # unconditional .cuda(), which broke CPU training).
            loss = mixup_criterion(criterion, output, target, target.flip(0), lam)

        optimizer.zero_grad()
        loss.backward()

        if args.clip_grad_norm > 0:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm, norm_type=2)

        optimizer.step()


def cls_validate(val_loader, model, criterion, args, epoch=None, time_begin=None):
    """Evaluate sample-weighted top-1 accuracy over ``val_loader`` and print it.

    Args:
        val_loader: iterable of ``(images, target)`` batches.
        model: model to evaluate (switched to eval mode).
        criterion: unused; kept for the caller's signature.
        args: parsed CLI namespace (reads no_cuda and gpu_id). The former
            per-batch logging, which was gated on print_freq and had its print
            commented out, was dead code that crashed on ``--print-freq 0``
            (``i % 0``), so it has been removed.
        epoch: 0-based epoch index for the log line; ``None`` prints ``?``
            (previously it crashed on ``None + 1``).
        time_begin: ``time()`` at training start; ``None`` prints ``-1.00``.

    Returns:
        float: top-1 accuracy in percent, or 0.0 if the loader is empty.
    """
    model.eval()
    use_cuda = (not args.no_cuda) and torch.cuda.is_available()
    acc1_sum, n = 0.0, 0
    with torch.no_grad():
        for images, target in val_loader:
            if use_cuda:
                images = images.cuda(args.gpu_id, non_blocking=True)
                target = target.cuda(args.gpu_id, non_blocking=True)

            output = model(images)
            acc1 = accuracy(output, target)
            batch = images.size(0)
            n += batch
            # Weight each batch's percentage by its size for a per-sample mean.
            acc1_sum += float(acc1[0] * batch)

    avg_acc1 = acc1_sum / n if n else 0.0
    total_mins = -1 if time_begin is None else (time() - time_begin) / 60
    epoch_label = '?' if epoch is None else epoch + 1
    print(f'[Epoch {epoch_label}] \t \t Top-1 {avg_acc1:6.2f} \t \t Time: {total_mins:.2f}')

    return avg_acc1


# Start training only when run as a script, not when this module is imported.
if __name__ == '__main__':
    main()
